{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MAGNeT\n",
    "Welcome to MAGNeT's demo jupyter notebook. \n",
    "Here you will find a self-contained example of how to use MAGNeT for music/sound-effect generation.\n",
    "\n",
    "First, we start by initializing MAGNeT for music generation, you can choose a model from the following selection:\n",
    "1. facebook/magnet-small-10secs - a 300M non-autoregressive transformer capable of generating 10-second music conditioned on text.\n",
    "2. facebook/magnet-medium-10secs - 1.5B parameters, 10 seconds music samples.\n",
    "3. facebook/magnet-small-30secs - 300M parameters, 30 seconds music samples.\n",
    "4. facebook/magnet-medium-30secs - 1.5B parameters, 30 seconds music samples.\n",
    "\n",
    "We will use the `facebook/magnet-small-10secs` variant for the purpose of this demonstration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from audiocraft.models import MAGNeT\n",
    "\n",
    "model = MAGNeT.get_pretrained('facebook/magnet-small-10secs')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let us configure the generation parameters. Specifically, you can control the following:\n",
    "* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.\n",
    "* `top_k` (int, optional): top_k used for sampling. Defaults to 0.\n",
    "* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.9.\n",
    "* `temperature` (float, optional): Initial softmax temperature parameter. Defaults to 3.0.\n",
    "* `max_clsfg_coef` (float, optional): Initial coefficient used for classifier free guidance. Defaults to 10.0.\n",
    "* `min_clsfg_coef` (float, optional): Final coefficient used for classifier free guidance. Defaults to 1.0.\n",
    "* `decoding_steps` (list of n_q ints, optional): The number of iterative decoding steps, for each of the n_q RVQ codebooks.\n",
    "* `span_arrangement` (str, optional): Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1') \n",
    "                                      in the masking scheme. \n",
    "\n",
    "When left unchanged, MAGNeT will revert to its default parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.set_generation_params(\n",
    "    use_sampling=True,\n",
    "    top_k=0,\n",
    "    top_p=0.9,\n",
    "    temperature=3.0,\n",
    "    max_cfg_coef=10.0,\n",
    "    min_cfg_coef=1.0,\n",
    "    decoding_steps=[int(20 * model.lm.cfg.dataset.segment_duration // 10),  10, 10, 10],\n",
    "    span_arrangement='stride1'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we can go ahead and start generating music given textual prompts."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Text-conditional Generation - Music"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from audiocraft.utils.notebook import display_audio\n",
    "\n",
    "###### Text-to-music prompts - examples ######\n",
    "text = \"80s electronic track with melodic synthesizers, catchy beat and groovy bass\"\n",
    "# text = \"80s electronic track with melodic synthesizers, catchy beat and groovy bass. 170 bpm\"\n",
    "# text = \"Earthy tones, environmentally conscious, ukulele-infused, harmonic, breezy, easygoing, organic instrumentation, gentle grooves\"\n",
    "# text = \"Funky groove with electric piano playing blue chords rhythmically\"\n",
    "# text = \"Rock with saturated guitars, a heavy bass line and crazy drum break and fills.\"\n",
    "# text = \"A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle\"\n",
    "                   \n",
    "N_VARIATIONS = 3\n",
    "descriptions = [text for _ in range(N_VARIATIONS)]\n",
    "\n",
    "print(f\"text prompt: {text}\\n\")\n",
    "output = model.generate(descriptions=descriptions, progress=True, return_tokens=True)\n",
    "display_audio(output[0], sample_rate=model.compression_model.sample_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Text-conditional Generation - Sound Effects"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Besides music, MAGNeT models can generate sound effects given textual prompts. \n",
    "First, let's load an Audio-MAGNeT model, out of the following collection: \n",
    "1. facebook/audio-magnet-small - a 300M non-autoregressive transformer capable of generating 10 second sound effects conditioned on text.\n",
    "2. facebook/audio-magnet-medium - 10 second sound effect generation, 1.5B parameters.\n",
    "\n",
    "We will use the `facebook/audio-magnet-small` variant for the purpose of this demonstration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from audiocraft.models import MAGNeT\n",
    "\n",
    "model = MAGNeT.get_pretrained('facebook/audio-magnet-small')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The recommended parameters for sound generation are a bit different than the defaults in MAGNeT, let's initialize it: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.set_generation_params(\n",
    "    use_sampling=True,\n",
    "    top_k=0,\n",
    "    top_p=0.8,\n",
    "    temperature=3.5,\n",
    "    max_cfg_coef=20.0,\n",
    "    min_cfg_coef=1.0,\n",
    "    decoding_steps=[int(20 * model.lm.cfg.dataset.segment_duration // 10),  10, 10, 10],\n",
    "    span_arrangement='stride1'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we can go ahead and start generating sounds given textual prompts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from audiocraft.utils.notebook import display_audio\n",
    "               \n",
    "###### Text-to-audio prompts - examples ######\n",
    "text = \"Seagulls squawking as ocean waves crash while wind blows heavily into a microphone.\"\n",
    "# text = \"A toilet flushing as music is playing and a man is singing in the distance.\"\n",
    "\n",
    "N_VARIATIONS = 3\n",
    "descriptions = [text for _ in range(N_VARIATIONS)]\n",
    "\n",
    "print(f\"text prompt: {text}\\n\")\n",
    "output = model.generate(descriptions=descriptions, progress=True, return_tokens=True)\n",
    "display_audio(output[0], sample_rate=model.compression_model.sample_rate)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "vscode": {
   "interpreter": {
    "hash": "b02c911f9b3627d505ea4a19966a915ef21f28afb50dbf6b2115072d27c69103"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
